(ns proteins
  (:require
   [scicloj.tempfiles.api :as tempfiles]
   [tablecloth.api :as tc]
   [fastmath.core :as math]
   [fastmath.random :as random]
   [tech.v3.datatype :as dtype]
   [tech.v3.dataset :as dataset]
   [tech.v3.dataset.tensor :as dataset.tensor]
   [tech.v3.tensor :as tensor]
   [tech.v3.datatype.functional :as fun]
   [aerial.hanami.common :as hc]
   [aerial.hanami.templates :as ht]
   [scicloj.kindly.v3.kind :as kind]
   [scicloj.kindly.v3.api :as kindly]
   [scicloj.clay.v2.api :as clay]
   [libpython-clj2.python :refer [py. py.. py.-] :as py]
   [scicloj.noj.v1.vis :as vis]
   [scicloj.noj.v1.vis.python :as vis.python]
   [libpython-clj2.require :refer [require-python]]))

Reference: the PyMC overview notebook, https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_overview.html

(require-python '[builtins :as python]
                'operator
                '[arviz :as az]
                '[arviz.style :as az.style]
                '[pandas :as pd]
                '[matplotlib.pyplot :as plt]
                '[numpy :as np]
                '[numpy.random :as np.random]
                '[pymc :as pm]
                '[Bio.PDB.PDBParser]
                '[Bio.PDB]
                '[Bio.PDB.Polypeptide]
                '[pytensor]
                '[pytensor.tensor :as pt]
                '[math])
:ok
(defn brackets
  "Python subscript access: invokes `__getitem__` on `obj` with `entry`,
  i.e. the equivalent of Python's `obj[entry]`."
  [obj entry]
  (py/call-attr obj "__getitem__" entry))
;; Python `slice(None, None)` built through the `builtins` module (aliased as
;; `python`); presumably meant to stand in for a bare `:` when subscripting
;; with `brackets`.
(def colon
  (python/slice nil nil))
;; Select the "arviz-darkgrid" plotting style, via the `az.style` alias
;; declared in the `require-python` form.
(az.style/use "arviz-darkgrid")
nil
(defn extract-coordinates-from-pdb
  "Parses `data/<protein-name>.pdb` with Biopython's `PDBParser` and returns a
  float32 tensor built from, per model, the \"CA\" atom coordinates of every
  residue that `Bio.PDB.Polypeptide/is_aa` accepts as a standard amino acid
  (chains of a model are concatenated).

  `data-type` defaults to `:models`, which is the only clause handled; any
  other value makes `case` throw."
  ([protein-name]
   (extract-coordinates-from-pdb protein-name :models))
  ([protein-name data-type]
   (let [pdb-path (str "data/" protein-name ".pdb")
         structure (py. (Bio.PDB/PDBParser) get_structure protein-name pdb-path)
         ;; Keep only residues whose name is a standard amino acid.
         standard-aa? (fn [residue]
                        (Bio.PDB.Polypeptide/is_aa (py. residue get_resname)
                                                   :standard true))
         ;; Coordinates of the residue's "CA" atom as a float32 array.
         ca-coord (fn [residue]
                    (dtype/->array :float32
                                   (py. (brackets residue "CA") get_coord)))
         chain-coords (fn [chain]
                        (map ca-coord (filter standard-aa? chain)))
         model-coords (fn [model]
                        (mapcat chain-coords model))]
     (case data-type
       :models (tensor/->tensor (map model-coords structure)
                                :datatype :float32)))))
(comment
  ;; REPL scratch: load coordinates with the default `:models` data-type.
  ;; Reads data/1d3z.pdb.
  (-> "1d3z"
      extract-coordinates-from-pdb)

  ;; Same call for data/1ubq.pdb.
  (->> "1ubq"
       extract-coordinates-from-pdb))
(defn center-1d
  "Returns `xs` with its mean (`fun/mean`) subtracted from every element."
  [xs]
  (let [mu (fun/mean xs)]
    (fun/- xs mu)))
(defn center-columns
  "Applies `center-1d` along axis 0 of `xyzs` using `tensor/map-axis`.
  As the name suggests, the intent is to center each column separately."
  [xyzs]
  (tensor/map-axis xyzs center-1d 0))
(comment
  ;; REPL scratch: try `center-columns` on a small 2x3 tensor.
  (-> [[1 2 3]
       [4 5 9]]
      tensor/->tensor
      center-columns))
;; NOTE(review): `center-columns` is already defined with the same body
;; earlier in this file; this form re-defines it identically.
(defn center-columns
  "Runs `center-1d` along axis 0 of `xyzs` through `tensor/map-axis`."
  [xyzs]
  (tensor/map-axis xyzs center-1d 0))
(defn read-data
  "Loads one model per protein and returns the raw and centered coordinates.

  Arguments:
  - `prots`: sequence of protein names, each resolved by
    `extract-coordinates-from-pdb`.
  - options map (optional):
    - `:data-type` - only `:models` is supported (default `:models`).
    - `:models`    - model index to pick for each protein, paired
                     positionally with `prots` (default `[0 1]`).
    - `:rmsd?`     - accepted, default `true`; not used by this function yet.

  Returns a map with:
  - `:coords`       - vector of the selected model's coordinates per protein.
  - `:obs`          - vector of those coordinates passed through
                      `center-columns`.
  - `:obs-datasets` - vector of datasets built from `:obs`, with columns
                      renamed to `:x`, `:y`, `:z`.

  Throws `IllegalArgumentException` for an unsupported `:data-type`."
  ([prots]
   (read-data prots {}))
  ([prots {:keys [data-type models rmsd?]
           ;; The `:or` keys must match the bound symbols; the default for
           ;; `rmsd?` was previously keyed as `rmsd` and never applied.
           :or {data-type :models
                models [0 1]
                rmsd? true}}]
   (case data-type
     :models (let [;; `mapv` realizes the Python-backed extraction eagerly,
                   ;; consistently with `obs` and `obs-datasets` below.
                   coords (mapv (fn [prot model]
                                  (-> prot
                                      extract-coordinates-from-pdb
                                      (nth model)))
                                prots
                                models)
                   obs (mapv center-columns coords)
                   obs-datasets (mapv (fn [centered]
                                        (-> centered
                                            dataset.tensor/tensor->dataset
                                            (tc/rename-columns [:x :y :z])))
                                      obs)]
               {:coords coords
                :obs obs
                :obs-datasets obs-datasets})
     (throw (IllegalArgumentException.
             (str "Unsupported data-type: " (pr-str data-type)
                  " (only :models is supported)"))))))
;; Column summaries (`tc/info`) of the centered-coordinate datasets for
;; model 4 of 1d3z and model 0 of 1ubq.
(let [protein-names ["1d3z" "1ubq"]
      models [4 0]]
  (->> (read-data protein-names
                  {:models models})
       :obs-datasets
       (map tc/info)))
;; 3D plot of the two centered structures (model 4 of 1d3z, model 0 of 1ubq).
(let [protein-names ["1d3z" "1ubq"]
      models [4 0]
      {:keys [obs-datasets]} (read-data protein-names
                                        {:models models})]
  (kind/hiccup
   ;; The quoted function is passed as data together with its argument map;
   ;; it merges each dataset's columns over the shared trace settings.
   ['(fn [{:keys [datasets]}]
       [plotly
        {:data (->> datasets
                    (mapv (fn [dataset]
                            (->> dataset
                                 (merge {:type :scatter3d
                                         :mode :lines+markers
                                         :opacity 0.6
                                         :line {:width 10}
                                         :marker {:size 4}})))))}])
    ;; Turn every dataset into a map of column name -> plain vector.
    {:datasets (->> obs-datasets
                    (mapv #(update-vals % vec)))}]))
(defn ->max-distance-to-origin
  "Squares `centered-structure` elementwise, sums along axis 1, takes the
  square root of each sum and returns the largest value.
  Assumes each row is one point, so the result is the largest Euclidean
  distance of a point from the origin - TODO confirm layout with callers."
  [centered-structure]
  (let [squares (fun/sq centered-structure)
        squared-norms (tensor/reduce-axis squares fun/sum 1)
        norms (fun/sqrt squared-norms)]
    (fun/reduce-max norms)))
(defn ->average-structure
  "Elementwise mean of `centered-structures`: their sum (`fun/+` applied to
  all of them) divided by how many there are."
  [centered-structures]
  (let [total (apply fun/+ centered-structures)
        n (count centered-structures)]
    (fun// total n)))

Trying out PyTensor, following https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_pytensor.html

;; Build the symbolic expression (x + y) * 2, compile it with
;; `pytensor/function`, and evaluate it at x = 10, y = 5.
(let [x (pt/scalar :name "x")
      y (pt/scalar :name "y")
      doubled-sum (pt/mul (operator/add x y) 2)
      compiled (pytensor/function :inputs [x y]
                                  :outputs doubled-sum)]
  (compiled :x 10
            :y 5))
30.0
;; Builds a PyMC model shaped after the first observed structure and draws
;; prior-predictive samples from it (PyMC's default number of draws, since no
;; argument is passed to `pm/sample_prior_predictive`).
;; Holds a map with the samples under `:prior-predictive-samples`.
(def results
  (let [name1 "1d3z"
        name2 "1ubq"
        models [4 0]
        {:keys [obs]} (read-data [name1 name2]
                                 {:models models})
        ;; Shape of the first centered structure; reused for M, Q1 and Q2.
        shape (dtype/shape (obs 0))]
    (py/with [model (pm/Model)]
             (let [M (pm/Normal "M" :shape shape)
                   ;; NOTE(review): `pt/mean` is called without an axis, so
                   ;; this subtracts the mean over all entries of M rather
                   ;; than a per-column mean (compare `center-columns`) -
                   ;; confirm this is intended.
                   M0 (pm/Deterministic "M0"
                                        (operator/sub
                                         M
                                         (pt/mean M)))
                   t1 (pm/Normal "t1" :shape [(shape 1)])
                   t2 (pm/Normal "t2" :shape [(shape 1)])
                   ;; `u` is indexed at 0, 1 and 2 below, so `(shape 1)`
                   ;; must be at least 3.
                   u (pm/Uniform "u0" :shape [(shape 1)])
                   ;; theta1 = 2*pi*u[1], theta2 = 2*pi*u[2]
                   theta1 (-> u
                              (brackets 1)
                              (operator/mul 2)
                              (operator/mul math/PI))
                   theta2 (-> u
                              (brackets 2)
                              (operator/mul 2)
                              (operator/mul math/PI))
                   ;; r1 = sqrt(1 - u[0]), r2 = sqrt(u[0])
                   r1 (-> u
                          (brackets 0)
                          (->> (operator/sub 1))
                          pt/sqrt)
                   r2 (-> u
                          (brackets 0)
                          pt/sqrt)
                   ;; (w, x, y, z) has unit norm because r1^2 + r2^2 = 1:
                   ;; a unit quaternion parameterized by the three uniforms.
                   w (-> theta2
                         (pt/cos)
                         (operator/mul r2))
                   x (-> theta1
                         (pt/sin)
                         (operator/mul r1))
                   y (-> theta1
                         (pt/cos)
                         (operator/mul r1))
                   z (-> theta2
                         (pt/sin)
                         (operator/mul r2))
                   ;; Entries of the 3x3 rotation matrix for quaternion
                   ;; (w, x, y, z): diagonal first, then off-diagonal terms.
                   R00 (operator/sub (operator/add (pt/sqr w)
                                                   (pt/sqr x))
                                     (operator/add (pt/sqr y)
                                                   (pt/sqr z)))
                   R11 (operator/sub (operator/add (pt/sqr w)
                                                   (pt/sqr y))
                                     (operator/add (pt/sqr x)
                                                   (pt/sqr z)))
                   R22 (operator/sub (operator/add (pt/sqr w)
                                                   (pt/sqr z))
                                     (operator/add (pt/sqr x)
                                                   (pt/sqr y)))
                   R01 (operator/mul 2
                                     (operator/sub (operator/mul x y)
                                                   (operator/mul w z)))
                   R02 (operator/mul 2
                                     (operator/add (operator/mul x z)
                                                   (operator/mul w y)))
                   R10 (operator/mul 2
                                     (operator/add (operator/mul x y)
                                                   (operator/mul w z)))
                   R12 (operator/mul 2
                                     (operator/sub (operator/mul y z)
                                                   (operator/mul w x)))
                   R20 (operator/mul 2
                                     (operator/sub (operator/mul x z)
                                                   (operator/mul w y)))
                   R21 (operator/mul 2
                                     (operator/add (operator/mul y z)
                                                   (operator/mul w x)))
                   R (pm/Deterministic "R"
                                       (pt/stack [(pt/stack [R00 R01 R02])
                                                  (pt/stack [R10 R11 R12])
                                                  (pt/stack [R20 R21 R22])]))
                   ;; t2, U, Q1 and Q2 are not referenced again in this
                   ;; form; they are still created inside the model context
                   ;; like the other variables.
                   U (pm/HalfNormal "U"
                                    :sigma 0.01
                                    :shape (shape 0))
                   Q1 (pm/Normal "Q1" :shape shape)
                   Q2 (pm/Normal "Q2" :shape shape)
                   ;; debug = dot(M0, R) + t1
                   debug (pm/Deterministic "debug"
                                           (-> M0
                                               (pt/dot R)
                                               (pt/add t1)))
                   prior-predictive-samples (pm/sample_prior_predictive)]
               {:prior-predictive-samples prior-predictive-samples}))))
;; `np/mean` over the prior draws of "M0".
(let [prior (py/get-attr (:prior-predictive-samples results) "prior")]
  (np/mean (py/get-attr prior "M0")))
<xarray.DataArray 'M0' ()>
array(-1.2465662e-18)
;; The prior draws of the "debug" deterministic.
(let [prior (py/get-attr (:prior-predictive-samples results) "prior")]
  (py/get-attr prior "debug"))
<xarray.DataArray 'debug' (chain: 1, draw: 500, debug_dim_0: 76, debug_dim_1: 3)>
array([[[[ 0.16706071,  2.6997921 ,  3.80579563],
         [-0.66583907,  0.11061794,  2.35875333],
         [-0.33844605,  3.2677548 ,  4.74898799],
         ...,
         [-1.05468472,  1.20188624,  2.56099598],
         [-2.39869129,  2.35258191,  0.60791435],
         [-2.87583178,  0.60094867,  2.95815657]],

        [[ 1.2683437 ,  0.08588472, -1.20021449],
         [-0.15936383,  0.32123272,  0.1536795 ],
         [-0.17914647, -0.94234472, -1.39809149],
         ...,
         [ 0.3398098 , -0.72257216, -0.05470593],
         [ 0.93696573, -0.19657193,  0.44033832],
         [-1.15839049, -2.90982475, -0.5328221 ]],

        [[-4.51618532,  1.59710832, -0.17286661],
         [-3.50154325,  0.79761203, -1.26850329],
         [-1.18601627, -0.30571579, -1.08569199],
         ...,
...
         ...,
         [-0.77592396,  2.29038249, -1.87501885],
         [-1.69725733, -0.56947283, -1.97132729],
         [-0.13016795,  1.55273405, -1.51042059]],

        [[ 0.40471076, -0.97868234, -0.82151834],
         [-0.16058526,  1.07207209, -0.25623372],
         [-0.51692006,  1.04154915, -1.78367788],
         ...,
         [ 0.74325271, -1.1145294 ,  0.32821188],
         [ 1.4166422 ,  0.70435817, -0.76661303],
         [ 0.68884763, -0.5847596 , -1.04314515]],

        [[-0.27730378,  2.17870692,  0.53384238],
         [ 0.50581583,  0.94524572, -0.2116013 ],
         [-0.46537371,  1.8002356 , -0.72092065],
         ...,
         [-1.53258094,  0.16178839, -1.10599859],
         [-0.01769912,  0.63115223, -0.5830774 ],
         [-0.32403442,  0.87416766,  1.37457161]]]])
Coordinates:
  * chain        (chain) int64 0
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * debug_dim_0  (debug_dim_0) int64 0 1 2 3 4 5 6 7 ... 68 69 70 71 72 73 74 75
  * debug_dim_1  (debug_dim_1) int64 0 1 2